"""
Phase 4: Is the weakening Z₁ × Resource interaction driven by oil prices or volumes?

Tests:
1. Rolling window interaction with oil price overlay
2. Normalize resource rents by oil price → does interaction stabilize?
3. Split by price regime (high/low oil) rather than time period
4. Add oil price as explicit control
5. Check volume (production) if available
"""

import sys, os
import pandas as pd
import numpy as np
import wbgapi as wb

sys.path.insert(0, '/mnt/c/demographics_capital_flows/multilateral/src')
from model import PanelGLS

OUT = '/mnt/c/demographics_capital_flows/eu_demographics/output/tables'

# ── Load panel ─────────────────────────────────────────────────────────
# Processed country-year panel; everything after 2024 is discarded.
_raw_panel = pd.read_csv('/mnt/c/demographics_capital_flows/multilateral/140_country/data/processed/full_panel_with_resources.csv')
panel = _raw_panel.loc[_raw_panel['year'].le(2024)].copy()
del _raw_panel  # only the filtered copy is used from here on

# Macro controls shared by every specification, appended after the three
# Z factors.
controls = ['fiscal_bal_gdp', 'nfa_gdp_lag', 'rgdp_growth', 'log_rel_opw', 'kaopen']
base_vars = ['Z_1', 'Z_2', 'Z_3', *controls]

def run_model(name, dep, indep, data, quiet=False, min_obs=50):
    """Fit a PanelGLS regression of *dep* on *indep* and optionally print it.

    Parameters
    ----------
    name : str
        Label used in printed output and stored in the returned record.
    dep : str
        Dependent-variable column in *data*.
    indep : sequence of str
        Regressor columns, in the order coefficients are reported.
        Any sequence is accepted; it is copied into a list.
    data : pandas.DataFrame
        Must contain 'iso3', 'year', *dep* and every column in *indep*.
    quiet : bool
        When True, skip the coefficient table.
    min_obs : int
        Minimum complete-case rows required to fit. The default of 50
        matches the threshold this script has always used.

    Returns
    -------
    dict or None
        Keys: 'model', 'n_obs', 'r_squared', 'vars' (the regressor list),
        'beta', 'se', 'pvalues'. Returns None, after printing a message, when
        fewer than *min_obs* complete rows remain after dropping missings.
    """
    # A tuple would be read by pandas as a single column key, so force a list.
    indep = list(indep)
    est = data[['iso3', 'year', dep] + indep].dropna()
    if len(est) < min_obs:
        print(f"  {name}: insufficient obs ({len(est)})")
        return None
    m = PanelGLS()
    m.fit(est[dep].values, est[indep].values, est['iso3'].values, est['year'].values)
    if not quiet:
        print(f"\n  {name}: N={m.n_obs}, R²={m.r_squared:.4f}")
        # Reuse the module-level star helper so every table uses one
        # significance convention.
        for v, b, s, p in zip(indep, m.beta, m.se, m.pvalues):
            print(f"    {v:40s} {b:10.4f} ({s:.4f}){st(p)}")
    return {'model': name, 'n_obs': m.n_obs, 'r_squared': m.r_squared,
            'vars': indep, 'beta': m.beta, 'se': m.se, 'pvalues': m.pvalues}

def st(p):
    """Return significance stars for p-value *p*.

    '***' below 0.01, '**' below 0.05, '*' below 0.1, otherwise an empty
    string. NaN never compares below a cutoff, so it also yields ''.
    """
    for cutoff, stars in ((0.01, '***'), (0.05, '**'), (0.1, '*')):
        if p < cutoff:
            return stars
    return ''

# ═══════════════════════════════════════════════════════════════════════
# STEP 1: Get oil price data
# ═══════════════════════════════════════════════════════════════════════
print("Step 1: Oil price data")
print("=" * 70)

# Annual average crude oil prices (nominal USD/bbl), hardcoded from
# EIA/World Bank series.
# NOTE(review): Brent did not trade as a benchmark until the mid-1970s, so
# the earliest values are presumably a posted/benchmark proxy. Confirm the
# source before labelling the whole series "Brent".
oil_prices = {
    1970: 1.80, 1971: 2.24, 1972: 2.48, 1973: 3.29, 1974: 11.58,
    1975: 10.41, 1976: 11.63, 1977: 12.50, 1978: 12.70, 1979: 29.19,
    1980: 36.83, 1981: 35.93, 1982: 32.97, 1983: 29.55, 1984: 28.78,
    1985: 27.56, 1986: 14.43, 1987: 18.44, 1988: 14.92, 1989: 18.23,
    1990: 23.73, 1991: 20.00, 1992: 19.32, 1993: 16.97, 1994: 15.82,
    1995: 17.02, 1996: 20.67, 1997: 19.09, 1998: 12.72, 1999: 17.97,
    2000: 28.50, 2001: 24.44, 2002: 25.02, 2003: 28.83, 2004: 38.27,
    2005: 54.52, 2006: 65.14, 2007: 72.39, 2008: 96.94, 2009: 61.67,
    2010: 79.61, 2011: 111.26, 2012: 111.67, 2013: 108.66, 2014: 98.97,
    2015: 52.32, 2016: 43.73, 2017: 54.13, 2018: 71.34, 2019: 64.21,
    2020: 41.96, 2021: 70.68, 2022: 100.93, 2023: 82.62, 2024: 80.00,
}

oil_df = pd.DataFrame(list(oil_prices.items()), columns=['year', 'oil_price'])
oil_df['log_oil_price'] = np.log(oil_df['oil_price'])
# No "real" oil price is built here. Rescaling by the 2024 price times 80
# just reproduces the nominal series. A real series needs a deflator
# (e.g. US CPI), which this script does not load.

panel = panel.merge(oil_df[['year', 'oil_price', 'log_oil_price']], on='year', how='left')

print(f"Oil price merged. Range: {panel['oil_price'].min():.1f} to {panel['oil_price'].max():.1f}")

# Panel years outside the hardcoded table get NaN prices. Surface them,
# because later steps drop or split on oil_price.
n_missing_oil = int(panel['oil_price'].isna().sum())
if n_missing_oil:
    print(f"  WARNING: {n_missing_oil} panel rows have no oil price "
          f"(table covers {min(oil_prices)}-{max(oil_prices)})")

# ═══════════════════════════════════════════════════════════════════════
# STEP 2: Try to get production volume from WDI
# ═══════════════════════════════════════════════════════════════════════
print("\nStep 2: Oil production/fuel exports data")
print("=" * 70)


def _attach_wdi_indicator(frame, indicator, column, ok_label, fail_label):
    """Left-join one WDI indicator (years 1970-2024) onto *frame*.

    The wbgapi table (economies by 'YR####' columns) is stacked to long form
    and renamed to (iso3, year, *column*), then merged on ('iso3', 'year').
    On success, prints the non-missing count under *ok_label* and returns the
    merged frame. On any exception, prints "<fail_label> download failed: …"
    and returns *frame* untouched, so the script carries on without that
    column (best-effort download).
    """
    try:
        wide = wb.data.DataFrame(indicator, time=range(1970, 2025), labels=False)
        long_form = wide.stack().reset_index()
        long_form.columns = ['iso3', 'year', column]
        long_form['year'] = (long_form['year'].astype(str)
                             .str.replace('YR', '')
                             .astype(int))
        merged = frame.merge(long_form, on=['iso3', 'year'], how='left')
        print(f"  {ok_label}: {merged[column].notna().sum()} obs")
        return merged
    except Exception as e:
        print(f"  {fail_label} download failed: {e}")
        return frame


# Fuel exports as % of merchandise exports
panel = _attach_wdi_indicator(panel, 'TX.VAL.FUEL.ZS.UN', 'fuel_exports_pct',
                              'Fuel exports (% merch)', 'Fuel exports')
# Energy production (kt of oil equivalent)
panel = _attach_wdi_indicator(panel, 'EG.EGY.PROD.KT.OE', 'energy_production_kt',
                              'Energy production (kt OE)', 'Energy production')

# ═══════════════════════════════════════════════════════════════════════
# STEP 3: Create price-adjusted resource rents
# ═══════════════════════════════════════════════════════════════════════
print("\nStep 3: Price-adjusted variables")
print("=" * 70)

# Resource rents already reflect prices (rents = (price - cost) × volume / GDP).
# Dividing by the oil price (floored at $1) strips out the common price level.
# NOTE(review): if resource_rents_gdp is total natural-resource rents (not
# just oil), the oil price deflates only part of it. Confirm the definition.
panel['rents_per_dollar_oil'] = panel['resource_rents_gdp'] / panel['oil_price'].clip(lower=1)
panel['Z1_x_rents_normalized'] = panel['Z_1'] * panel['rents_per_dollar_oil']

# Also create: rents × oil_price interaction to test if amplification is price-driven
panel['Z1_x_resource'] = panel['Z_1'] * panel['resource_rents_gdp']
panel['Z1_x_oil_price'] = panel['Z_1'] * panel['log_oil_price']
panel['resource_x_oil'] = panel['resource_rents_gdp'] * panel['log_oil_price']
panel['Z1_x_res_x_oil'] = panel['Z_1'] * panel['resource_rents_gdp'] * panel['log_oil_price']

# High/low oil price regimes, split at the median over country-year rows.
# Rows without an oil price get NaN, not 0. Otherwise they would be silently
# counted in the low-price regime, because NaN > median evaluates False.
median_oil = panel['oil_price'].median()
panel['high_oil'] = (
    (panel['oil_price'] > median_oil).astype(float).where(panel['oil_price'].notna())
)
panel['Z1_x_resource_x_high_oil'] = panel['Z1_x_resource'] * panel['high_oil']

print(f"  Median oil price: ${median_oil:.1f}")
print(f"  High oil obs: {panel['high_oil'].sum():.0f}, Low oil obs: {(1-panel['high_oil']).sum():.0f}")

# ═══════════════════════════════════════════════════════════════════════
# STEP 4: Core test — does adding oil price absorb the time variation?
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("STEP 4: Oil Price Controls")
print("=" * 70)

# Every model starts from the baseline Z₁ × Resource specification and
# layers oil-price terms on top. If Z1_x_resource shrinks as they enter,
# its time variation is plausibly price-driven.
core_terms = base_vars + ['resource_rents_gdp', 'Z1_x_resource']
oil_specs = (
    # (banner, model label, extra regressors)
    ("\n--- A: Baseline Z₁ × Resource ---", 'A: Z₁ × Resource', []),
    ("\n--- B: + Log Oil Price ---", 'B: + log(oil)', ['log_oil_price']),
    ("\n--- C: + Z₁ × log(oil) ---", 'C: + Z₁×oil',
     ['log_oil_price', 'Z1_x_oil_price']),
    ("\n--- D: Triple Z₁ × Resource × log(oil) ---", 'D: triple',
     ['log_oil_price', 'Z1_x_oil_price', 'resource_x_oil', 'Z1_x_res_x_oil']),
)
for banner, label, extra in oil_specs:
    print(banner)
    run_model(label, 'ca_gdp', core_terms + extra, panel)

# ═══════════════════════════════════════════════════════════════════════
# STEP 5: Price-normalized rents
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("STEP 5: Price-Normalized Resource Rents")
print("=" * 70)

norm_terms = base_vars + ['rents_per_dollar_oil', 'Z1_x_rents_normalized']

# Full-sample fit. If dividing rents by the oil price removes the price
# channel, the interaction should look similar across the sub-periods below.
print("\n--- Normalized rents (rents / oil_price) ---")
run_model('Normalized: Z₁ × (Rents/Oil)', 'ca_gdp', norm_terms, panel)

# Same specification within fixed sub-periods (years inclusive).
# NOTE(review): the last period stops at 2021 although the panel runs to
# 2024. Confirm this is intentional.
for first_year, last_year in ((1980, 1999), (2000, 2009), (2010, 2021)):
    period = f'{first_year}-{last_year}'
    window = panel[panel['year'].between(first_year, last_year)]
    res = run_model(f'Norm {period}', 'ca_gdp', norm_terms, window, quiet=True)
    if not res:
        continue
    pos_int = res['vars'].index('Z1_x_rents_normalized')
    pos_rents = res['vars'].index('rents_per_dollar_oil')
    coef, pval = res['beta'], res['pvalues']
    print(f"  {period}: N={res['n_obs']}, "
          f"Z₁×(Rents/Oil)={coef[pos_int]:.4f}{st(pval[pos_int])}, "
          f"Rents/Oil={coef[pos_rents]:.3f}{st(pval[pos_rents])}")

# ═══════════════════════════════════════════════════════════════════════
# STEP 6: High vs Low oil price regimes
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("STEP 6: High vs Low Oil Price Regimes")
print("=" * 70)

# Median split using high_oil from Step 3. Rows with no oil price are kept
# out of both regimes, so a 0-coded high_oil cannot push them into "low".
has_oil = panel['oil_price'].notna()
high_oil_panel = panel[(panel['high_oil'] == 1) & has_oil]
low_oil_panel = panel[(panel['high_oil'] == 0) & has_oil]

print(f"\nHigh oil (>${median_oil:.0f}/bbl):")
r_high = run_model('High oil: Z₁ × Resource', 'ca_gdp',
                   base_vars + ['resource_rents_gdp', 'Z1_x_resource'], high_oil_panel)

print(f"\nLow oil (<=${median_oil:.0f}/bbl):")
r_low = run_model('Low oil: Z₁ × Resource', 'ca_gdp',
                  base_vars + ['resource_rents_gdp', 'Z1_x_resource'], low_oil_panel)

# Also try terciles of the country-year oil price distribution. Bins are
# [lo, hi). The top bin is open-ended (np.inf) so no observation is dropped
# by an arbitrary cap, and its printed upper bound is the observed maximum.
p33 = panel['oil_price'].quantile(0.33)
p67 = panel['oil_price'].quantile(0.67)
oil_max = panel['oil_price'].max()
tercile_bins = [('Low oil', 0, p33), ('Mid oil', p33, p67), ('High oil', p67, np.inf)]
for label, lo, hi in tercile_bins:
    sub = panel[(panel['oil_price'] >= lo) & (panel['oil_price'] < hi)]
    r = run_model(f'{label} tercile', 'ca_gdp',
                  base_vars + ['resource_rents_gdp', 'Z1_x_resource'], sub, quiet=True)
    if r:
        int_i = r['vars'].index('Z1_x_resource')
        rr_i = r['vars'].index('resource_rents_gdp')
        z1_i = r['vars'].index('Z_1')
        print(f"  {label} (${lo:.0f}-${min(hi, oil_max):.0f}): N={r['n_obs']}, "
              f"Z₁={r['beta'][z1_i]:.3f}{st(r['pvalues'][z1_i])}, "
              f"Z₁×Res={r['beta'][int_i]:.4f}{st(r['pvalues'][int_i])}, "
              f"Rents={r['beta'][rr_i]:.3f}{st(r['pvalues'][rr_i])}")

# ═══════════════════════════════════════════════════════════════════════
# STEP 7: Rolling 10-year window
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("STEP 7: Rolling 10-Year Window — Z₁ × Resource Interaction")
print("=" * 70)

print(f"\n{'Window':15s} {'N':>5s} {'Z₁×Res':>10s} {'sig':>4s} {'Rents':>8s} {'sig':>4s} {'Avg Oil$':>9s}")
print("-" * 60)

# Inclusive 10-year windows starting each year from 1980 to 2012 (last
# window: 2012-2021). The average oil price is taken over all panel rows
# in the window, before run_model drops rows with missing regressors.
WINDOW_YEARS = 10
for first in range(1980, 2013):
    last = first + WINDOW_YEARS - 1
    in_window = panel[panel['year'].between(first, last)]
    avg_oil = in_window['oil_price'].mean()

    res = run_model(f'{first}-{last}', 'ca_gdp',
                    base_vars + ['resource_rents_gdp', 'Z1_x_resource'], in_window, quiet=True)
    if not res:
        continue
    coef, pval = res['beta'], res['pvalues']
    k_int = res['vars'].index('Z1_x_resource')
    k_res = res['vars'].index('resource_rents_gdp')
    print(f"  {first}-{last}      {res['n_obs']:5d} {coef[k_int]:10.4f} "
          f"{st(pval[k_int]):>4s} {coef[k_res]:8.3f} "
          f"{st(pval[k_res]):>4s} {avg_oil:9.1f}")

# ═══════════════════════════════════════════════════════════════════════
# STEP 8: Fuel exports as volume proxy
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("STEP 8: Fuel Exports (Volume Proxy)")
print("=" * 70)

# NOTE(review): TX.VAL.FUEL.ZS.UN is a *value* share (fuel exports as % of
# merchandise exports), so it co-moves with oil prices and is at best a
# partial volume proxy. Energy production (kt of oil equivalent, below) is
# the cleaner volume measure.
# Each series needs more than this many non-missing rows to be estimated.
MIN_SERIES_OBS = 500

if 'fuel_exports_pct' in panel.columns and panel['fuel_exports_pct'].notna().sum() > MIN_SERIES_OBS:
    panel['Z1_x_fuel'] = panel['Z_1'] * panel['fuel_exports_pct']
    # Fuel-exporter dummy (>= 50% of merchandise exports) and its Z₁
    # interaction. NOTE(review): not used by any model in this script.
    panel['fuel_high'] = (panel['fuel_exports_pct'] >= 50).astype(float)
    panel['Z1_x_fuel_high'] = panel['Z_1'] * panel['fuel_high']

    run_model('Fuel exports: Z₁ × fuel%', 'ca_gdp',
              base_vars + ['fuel_exports_pct', 'Z1_x_fuel'], panel)

    # By period (years inclusive)
    for period, (y1, y2) in [('1980-1999', (1980, 1999)),
                              ('2000-2009', (2000, 2009)),
                              ('2010-2021', (2010, 2021))]:
        sub = panel[(panel['year'] >= y1) & (panel['year'] <= y2)]
        r = run_model(f'Fuel {period}', 'ca_gdp',
                      base_vars + ['fuel_exports_pct', 'Z1_x_fuel'], sub, quiet=True)
        if r:
            int_i = r['vars'].index('Z1_x_fuel')
            print(f"  {period}: Z₁×fuel={r['beta'][int_i]:.4f}{st(r['pvalues'][int_i])}")
else:
    # Download failed in Step 2 or the series is too sparse. Say so instead
    # of silently skipping.
    print("  Fuel exports unavailable or too sparse; skipping fuel-export models.")

if 'energy_production_kt' in panel.columns and panel['energy_production_kt'].notna().sum() > MIN_SERIES_OBS:
    # Log production, floored at 1 kt so zero-production rows map to 0.
    panel['log_energy_prod'] = np.log(panel['energy_production_kt'].clip(lower=1))
    panel['Z1_x_energy_prod'] = panel['Z_1'] * panel['log_energy_prod']

    run_model('Energy prod: Z₁ × log(prod)', 'ca_gdp',
              base_vars + ['log_energy_prod', 'Z1_x_energy_prod'], panel)
else:
    print("  Energy production unavailable or too sparse; skipping energy-production model.")

# ═══════════════════════════════════════════════════════════════════════
# STEP 9: Decompose — price vs volume vs structural
# (implemented as correlations of the rolling interaction with oil price
#  and with time)
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("STEP 9: Correlation — rolling interaction vs oil price")
print("=" * 70)

# Collect rolling results
from scipy import stats

# Re-estimate the Step 7 windows (inclusive 10-year spans, 1980-1989 through
# 2012-2021). Keep the interaction coefficient and the mean oil price over
# the window's panel rows.
roll_data = []
for start in range(1980, 2013):
    end = start + 9
    sub = panel[(panel['year'] >= start) & (panel['year'] <= end)]
    avg_oil = sub['oil_price'].mean()

    r = run_model(f'{start}-{end}', 'ca_gdp',
                  base_vars + ['resource_rents_gdp', 'Z1_x_resource'], sub, quiet=True)
    if r:
        k_int = r['vars'].index('Z1_x_resource')
        roll_data.append({
            'midpoint': start + 4.5,
            'interaction': r['beta'][k_int],
            'interaction_p': r['pvalues'][k_int],
            'avg_oil': avg_oil,
            'log_avg_oil': np.log(avg_oil),
            'n_obs': r['n_obs']
        })

rd = pd.DataFrame(roll_data)
if len(rd) > 5:
    # NOTE: adjacent windows share 9 of 10 years, so these points are far
    # from independent. The p-values below assume independence and are
    # therefore optimistic; read them as descriptive.
    r_price, p_price = stats.pearsonr(rd['avg_oil'], rd['interaction'])
    r_logprice, p_logprice = stats.pearsonr(rd['log_avg_oil'], rd['interaction'])
    r_time, p_time = stats.pearsonr(rd['midpoint'], rd['interaction'])

    print(f"\n  Correlation: rolling Z₁×Res interaction vs:")
    print(f"    Avg oil price:      r={r_price:.3f}, p={p_price:.4f}")
    print(f"    Log avg oil price:  r={r_logprice:.3f}, p={p_logprice:.4f}")
    print(f"    Time (midpoint):    r={r_time:.3f}, p={p_time:.4f}")

    # Partial correlation: interaction vs time, controlling for oil price.
    # Residualize both on avg_oil, then correlate the residuals.
    slope_t, icpt_t, *_ = stats.linregress(rd['avg_oil'], rd['midpoint'])
    slope_i, icpt_i, *_ = stats.linregress(rd['avg_oil'], rd['interaction'])
    resid_time = rd['midpoint'] - (slope_t * rd['avg_oil'] + icpt_t)
    resid_interaction = rd['interaction'] - (slope_i * rd['avg_oil'] + icpt_i)
    r_partial, _ = stats.pearsonr(resid_time, resid_interaction)
    # Partialling out one variable costs a degree of freedom. Test on n - 3
    # df rather than the n - 2 that pearsonr applies to the residuals.
    df_partial = len(rd) - 3
    t_partial = r_partial * np.sqrt(
        df_partial / max(1.0 - r_partial ** 2, np.finfo(float).tiny))
    p_partial = 2.0 * stats.t.sf(abs(t_partial), df_partial)
    print(f"    Time | oil price:   r={r_partial:.3f}, p={p_partial:.4f} (partial)")

    # to_csv fails if the output directory has not been created yet.
    os.makedirs(OUT, exist_ok=True)
    rd.to_csv(os.path.join(OUT, 'rolling_interaction_vs_oil.csv'), index=False)
    print(f"\n  Saved: rolling_interaction_vs_oil.csv")
else:
    print(f"\n  Only {len(rd)} rolling windows estimated; skipping correlation analysis.")

print("\n✓ Phase 4 complete.")
